This document demonstrates the use of the weightedRF and weightedLASSO functions for integrative GRN inference.

It shows how to select one optimal alpha value per gene, so that data integration is applied only to the target genes whose expression prediction error is reduced by the PWM prior information significantly more than in a baseline model.

Data import

Import of the expression data, the N-responsive genes and regulators, and the PWM occurrence matrix:

# Expression data together with the target genes and candidate regulators.
load("rdata/inference_input_N_response_varala.rdata")
genes <- input_data$grouped_genes
length(genes)
## [1] 1426
tfs <- input_data$grouped_regressors
length(tfs)
## [1] 201
counts <- input_data$counts
dim(counts)
## [1] 1426   45

# PWM occurrence matrix (one row per gene, one column per regulator).
load("rdata/pwm_occurrences_N_response_varala.rdata")
dim(pwm_occurrence)
## [1] 1426  201

# Copy of the prior in which unknown (NA) entries are set to 0.5.
pwm_imputed <- replace(pwm_occurrence, is.na(pwm_occurrence), 0.5)

# Grid of data integration strengths explored throughout the document.
ALPHAS <- seq(from = 0, to = 1, by = 0.1)

Prediction error as a function of data integration strength (alpha)

The following steps are time and CPU intensive, so the result files can just be loaded to be analysed in further steps.

For weightedRF

Generating 100 repetitions of weightedRF MSE estimation on true data, and on shuffled data.

# Repeated estimation of the per-gene MSE of weightedRF for every integration
# strength in ALPHAS. Each (alpha, repetition) is run once with
# tf_expression_permutation = FALSE (columns labelled "true_data") and once
# with tf_expression_permutation = TRUE (columns labelled "shuffled").
# Rows of `lmses` are the target genes; column names have the form
# "<alpha> <repetition> <dataset>" (space separated).
lmses <- data.frame(row.names = genes)
N <- 100
for (alpha in ALPHAS) {
  # seq_len() rather than 1:N so that N = 0 runs no iteration at all
  for (perm in seq_len(N)) {
    lmses[, paste(as.character(alpha), perm, "true_data")] <-
      weightedRF_inference_MSE(counts, genes, tfs, alpha = alpha, nTrees = 2000,
                               pwm_occurrence = pwm_occurrence, nCores = 45,
                               tf_expression_permutation = FALSE)
  }

  for (perm in seq_len(N)) {
    lmses[, paste(as.character(alpha), perm, "shuffled")] <-
      weightedRF_inference_MSE(counts, genes, tfs, alpha = alpha, nTrees = 2000,
                               pwm_occurrence = pwm_occurrence, nCores = 45,
                               tf_expression_permutation = TRUE)
  }
}

save(lmses, file = "results/brf_perumtations_mse_all_genes_predict.rdata")
# NOTE(review): this variable shares its name with base::subset() and is not
# used in the visible part of the document; kept in case later code reads it.
subset <- unique(rownames(lmses))

For weightedLASSO

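Generating 50 repetitions of weightedLASSO MSE estimation on true data, and on shuffled data. The code below is commented out: the weightedLASSO MSE values are instead extracted from the inference results in the next section and saved under the same file name.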

# lmses <- data.frame(row.names = genes)
# N<-50
# for(alpha in ALPHAS){
#   for(perm in 1:N){
#     lmses[,paste(as.character(alpha), perm, "true_data")] <- weightedLASSO_inference_MSE(counts, genes, tfs, alpha = alpha, N=50,
#                              pwm_occurrence = pwm_occurrence, nCores = 40, tf_expression_permutation = F)
# 
#     lmses[,paste(as.character(alpha), perm, "shuffled")] <- weightedLASSO_inference_MSE(counts, genes, tfs, alpha = alpha, N=50,
#                            pwm_occurrence = pwm_occurrence, nCores = 40, tf_expression_permutation = T)
#   }
# }
# save(lmses, file = "results/lasso_perumtations_mse_all_genes_test.rdata")

Getting the importance metrics of inferred edges in order to measure effective data integration and build inferred GRNs

For weightedRF

# weightedRF importance matrices for every alpha in ALPHAS and every
# repetition, once with the default TF expression and once with
# tf_expression_permutation = TRUE. Results are stored in `mats` under names
# of the form bRF_<alpha>_<trueData|shuffled>_<rep>.
nCores <- 45
mats <- list()
nrep <- 100
for (alpha in ALPHAS) { # exploring PWM integration strength
  alpha_label <- as.character(alpha)
  for (rep in seq_len(nrep)) { # exploring inherent variability
    true_key <- paste0("bRF_", alpha_label, "_trueData_", rep)
    shuffled_key <- paste0("bRF_", alpha_label, "_shuffled_", rep)

    mat_rf <- weightedRF_inference(
      counts, genes, tfs,
      nTrees = 2000,
      alpha = alpha,
      pwm_occurrence = pwm_occurrence,
      nCores = nCores,
      importance = "%IncMSE"
    )

    mat_rf_perm <- weightedRF_inference(
      counts, genes, tfs,
      nTrees = 2000,
      alpha = alpha,
      tf_expression_permutation = TRUE,
      pwm_occurrence = pwm_occurrence,
      nCores = nCores,
      importance = "%IncMSE"
    )

    mats[[true_key]] <- mat_rf
    mats[[shuffled_key]] <- mat_rf_perm
  }
}
save(mats, file = "results/100_permutations_bRF_importances_inf.rdata")

# Turns every importance matrix into GRNs, one per density threshold.
# Networks are stored under "<matrix name>_<density>".
edges <- list()
densities <- c(0.005, 0.01, 0.05)
for (name in names(mats)) {
  importance_matrix <- mats[[name]]
  for (density in densities) { # exploring importance threshold stringency
    edge_key <- paste0(name, "_", density)
    edges[[edge_key]] <- weightedRF_network(importance_matrix,
                                            density = density,
                                            pwm_occurrence, genes, tfs)
  }
}
save(edges, file = "results/100_permutations_bRF_edges_inf.rdata")

# Validation of the weightedRF GRN edges against DAP-Seq.
# The network names (<model>_<alpha>_<dataset>_<rep>_<density>) are split
# back into one column per setting.
settings <- c("model", "alpha", "dataset", "rep", "density")
val_dap <-
  evaluate_networks(
    edges,
    validation = c("DAPSeq"),
    nCores = 35,
    input_genes = genes,
    input_tfs = tfs
  )
val_dap[, settings] <-
  str_split_fixed(val_dap$network_name, '_', length(settings))
# Saved under the "bRF" prefix, like the other weightedRF result files, so
# that the file name is the one loaded later in this document before the
# precision curves are drawn (the previous "rf" name was never read back).
save(val_dap, file = "results/100_permutations_bRF_validation_dap_inf.rdata")

For weightedLASSO

# weightedLASSO importance matrices for every alpha in ALPHAS and every
# repetition, with tf_expression_permutation = FALSE ("trueData") and
# tf_expression_permutation = TRUE ("shuffled"). Results are stored in `mats`
# under names of the form LASSO_<alpha>_<trueData|shuffled>_<rep>.
nCores <- 40
mats <- list()
nrep <- 50
# The true-data and shuffled fits must use the same N so that the shuffled
# baseline is obtained under identical settings (the true-data call previously
# used N = 10 while the shuffled call used N = 50).
n_lasso <- 50
for (alpha in ALPHAS) { # exploring PWM integration strength
  for (rep in seq_len(nrep)) { # exploring inherent variability

    mat_lasso <- weightedLASSO_inference(counts, genes, tfs,
                                         alpha = alpha, N = n_lasso,
                                         tf_expression_permutation = FALSE,
                                         int_pwm_noise = 0, mda_type = "shuffle",
                                         pwm_occurrence = pwm_occurrence,
                                         nCores = nCores)

    mat_lasso_perm <- weightedLASSO_inference(counts, genes, tfs,
                                              alpha = alpha, N = n_lasso,
                                              tf_expression_permutation = TRUE,
                                              int_pwm_noise = 0, mda_type = "shuffle",
                                              pwm_occurrence = pwm_occurrence,
                                              nCores = nCores)

    mats[[paste0("LASSO_", as.character(alpha), '_trueData_', rep)]] <- mat_lasso
    mats[[paste0("LASSO_", as.character(alpha), '_shuffled_', rep)]] <- mat_lasso_perm

  }
}
save(mats, file = "results/100_permutations_lasso_mda_shuffle.rdata")

# Thresholds the regulatory weights at several densities to build GRNs, and
# collects the "mse" row of every inference result into `lmses`
# (one column per element of `mats`, rows are the target genes).
edges <- list()
lmses <- data.frame(row.names = genes)
densities <- c(0.005, 0.01, 0.05)
for (name in names(mats)) {
  # The MSE row does not depend on the density: extract it once per matrix
  # instead of once per density.
  lmses[, name] <- mats[[name]]["mse", ]

  for (density in densities) { # exploring importance threshold stringency
    edges[[paste0(name, '_', density)]] <-
      weightedLASSO_network(mats[[name]], density = density, pwm_occurrence,
                            genes, tfs, decreasing = TRUE)
  }
}
save(edges, file = "results/100_permutations_lasso_edges.rdata")
save(lmses, file = "results/lasso_perumtations_mse_all_genes_test.rdata")


# Validation of the weightedLASSO GRN edges against DAP-Seq; the network
# names are split back into one column per setting.
settings <- c("model", "alpha", "dataset", "rep", "density")
val_dap <- evaluate_networks(edges,
                             validation = c("DAPSeq"),
                             nCores = 35,
                             input_genes = genes,
                             input_tfs = tfs)
name_parts <- str_split_fixed(val_dap$network_name, "_", length(settings))
val_dap[, settings] <- name_parts
save(val_dap, file = "results/100_permutations_lasso_validation_dap.rdata")

Showing the MSE and effective data integration behaviors

# Prior values: copy of the PWM occurrences where missing entries become 0.5
pwm_imputed <- replace(pwm_occurrence, is.na(pwm_occurrence), 0.5)

# Parallel counterpart of sapply(), used to compute the optimal alphas.
# X, FUN, simplify and USE.NAMES play the same role as in sapply(); the
# arguments in `...` are handed to parallel::mclapply() (e.g. mc.cores).
# Returns a list when simplify is FALSE or the result is empty, otherwise the
# output of simplify2array().
mcsapply <- function (X, FUN, ..., simplify = TRUE, USE.NAMES = TRUE) {
  FUN <- match.fun(FUN)
  results <- parallel::mclapply(X = X, FUN = FUN, ...)

  # unnamed results computed from a character input are named after the input
  name_from_input <- USE.NAMES && is.character(X) && is.null(names(results))
  if (name_from_input) {
    names(results) <- X
  }

  if (isFALSE(simplify) || !length(results)) {
    return(results)
  }
  simplify2array(results, higher = (simplify == "array"))
}

weightedRF

# importance metrics and MSE for weightedRF
# The first file is the one written by save(lmses, ...) and the second the one
# written by save(mats, ...) in the weightedRF sections above, so these loads
# replace any `lmses` and `mats` currently in the workspace.
load("results/brf_perumtations_mse_all_genes_predict.rdata")
load("results/100_permutations_bRF_importances_inf.rdata")


# needs an appropriate mats variable and a lmses variable to be loaded

# Plots, for one target gene, how the TFs that have a given prior value are
# ranked (or weighted) as alpha is increased.
#
# Reads the global objects `pwm_imputed` and `mats`; the names of `mats` are
# expected to have the form <method>_<alpha>_<dataset>_<rep>.
#
# gene  : name of the target gene (row of `pwm_imputed`, column of each matrix
#         in `mats`)
# prior : prior value selecting the TFs that are followed (default 1)
# type  : "rank" for the mean rank of those TFs, "imp" for their mean
#         importance
# Returns a ggplot object.
draw_gene_effective_integration <- function(gene, prior=1, type = "rank"){
  # fail early on an unknown type instead of falling through both branches
  type <- match.arg(type, c("rank", "imp"))
  tfs_with_motif <- names(which(pwm_imputed[gene,]== prior))
  if(type == "rank") {
    data <- data.frame(lapply(mats, function(mat){mean(rank(mat[,gene])[tfs_with_motif])}))
    measured <- "Rank"
  } else {
    data <- data.frame(lapply(mats, function(mat){mean(mat[,gene][tfs_with_motif])}))
    measured <- "Importance"
  }
  # the y label follows the arguments; with the defaults it reads
  # "Rank of TFs with Pi=1"
  y_label <- paste0(measured, " of TFs with Pi=", prior)
  dataset_colors <- setNames(c("grey", "#70AD47"), c("shuffled", "trueData"))

  plot <- data %>%
    gather(key = "setting", value = "summed_importance") %>%
    separate(setting, into = c("method", "alpha", "dataset", "rep"), sep = "_") %>%
    group_by(alpha, dataset) %>%
    mutate(mean_imp = mean(summed_importance),
           sd_imp = sd(summed_importance),
           alpha = as.numeric(alpha)) %>%
    ggplot(aes(x=alpha, y = summed_importance, color = dataset, fill = dataset)) +
    geom_point(alpha = 0.1) + geom_line(aes(y=mean_imp))+
    geom_ribbon(aes(ymin = mean_imp - sd_imp ,
                    ymax = mean_imp + sd_imp  ),
                alpha = .4)  +theme_pubr(legend = "none")+
    ylab(y_label)+
    ggtitle("Effective data integration") +
    labs(subtitle = gene)+
    scale_color_manual(values = dataset_colors)+
    scale_fill_manual(values = dataset_colors)
  plot + xlab(expression(alpha))
}


# Plots the MSE of one target gene as alpha is increased, for the true and
# the shuffled data (points, mean line and a +/- one standard deviation band).
#
# Reads the global `lmses`, whose column names are expected to have the form
# <model>_<alpha>_<dataset>_<rep>.
#
# gene  : name of the target gene (row of `lmses`)
# title : plot title; when NULL (default) the title is "MSE"
# Returns a ggplot object.
draw_gene_mse <- function(gene, title = NULL){
  # `title` was previously accepted but ignored
  plot_title <- if (is.null(title)) "MSE" else title
  dataset_colors <- setNames(c("grey", "#70AD47"), c("shuffled", "trueData"))

  lmses[gene, ] %>%
    gather() %>%
    separate(key,
             into = c("model", "alpha", "dataset",  "rep"),
             sep = "_") %>%
    group_by(alpha, dataset) %>%
    mutate(mean_mse = mean(value, na.rm = TRUE),
           sd_mse = sd(value, na.rm = TRUE)) %>%
    ggplot(aes(
      x = as.numeric(alpha),
      y = value,
      color = dataset,
      fill = dataset
    )) +ggtitle(plot_title)+ylab("MSE/Var(gene)")+
    geom_ribbon(aes(ymin = mean_mse - sd_mse ,
                    ymax = mean_mse + sd_mse  ),
                alpha = .4)  +theme_pubr(legend = "top")+
    geom_point(alpha = 0.1) + geom_line(aes(y=mean_mse))+xlab(expression(alpha))+
    scale_color_manual(values = dataset_colors)+
    scale_fill_manual(values = dataset_colors)
}


# Finds the optimal alpha of one target gene and, unless return_cluster is
# TRUE, plots the MSE as a function of effective data integration.
#
# Effective integration is the mean rank (type = "rank") or mean importance
# (type = "imp") of the TFs whose pwm_occurrence is 1 for the gene. The
# shuffled-data MSE curve is interpolated at the effective integration values
# of the true data; an alpha is a candidate when the true-data MSE is lower
# than the interpolated shuffled MSE by more than one interpolated standard
# deviation, and the candidate with the largest such ratio is retained.
#
# Reads the global objects `pwm_occurrence`, `mats` and `lmses`.
#
# gene           : name of the target gene
# type           : "rank" or "imp"
# return_cluster : if TRUE, return only the optimal alpha (0 when the gene has
#                  no TF with a motif or when no alpha is a candidate)
# Returns a single number when return_cluster is TRUE, otherwise a ggplot
# object (an empty one when the gene has no TF with a motif).
get_opt_alpha_per_gene <- function(gene, type = "rank", return_cluster = FALSE){
  type <- match.arg(type, c("rank", "imp"))
  tfs_with_motif <- names(which(pwm_occurrence[gene,]==1))

  # gene with no TFBS has an optimal alpha of 0 (and nothing to plot)
  if (length(tfs_with_motif) == 0) {
    if (return_cluster) return(0)
    return(ggplot())
  }

  if(type == "rank") {
    data <- data.frame(lapply(mats, function(mat){mean(rank(mat[,gene])[tfs_with_motif])}))
  } else {
    data <- data.frame(lapply(mats, function(mat){mean(mat[,gene][tfs_with_motif])}))
  }

  # effective integration per (alpha, dataset)
  inte <- data %>%
    gather(key = "setting", value = "summed_importance") %>%
    separate(setting, into = c("method", "alpha", "dataset", "rep"), sep = "_") %>%
    group_by(alpha, dataset) %>%
    mutate(mean_imp = mean(summed_importance, na.rm=TRUE),
           sd_imp = sd(summed_importance, na.rm=TRUE),
           alpha = as.numeric(alpha))

  # mean and sd of the MSE per (alpha, dataset), with the matching mean_imp
  curves <- lmses[gene, ] %>%
    gather() %>%
    separate(key,
             into = c("model", "alpha", "dataset",  "rep"),
             sep = "_") %>%
    group_by(alpha, dataset) %>%
    mutate(mean_mse = mean(value, na.rm = TRUE),
           sd_mse = sd(value, na.rm = TRUE)) %>%
    mutate(alpha = as.numeric(alpha))%>%
    full_join(inte, by = c("alpha", "dataset", "rep")) %>%
    group_by(alpha, mean_imp, dataset) %>%
    summarise(mean_mse = mean(value, na.rm = TRUE),
              sd_mse = sd(value, na.rm = TRUE))

  # NOTE(review): the interpolation points below come from `curves` as
  # returned by summarise() (presumably ordered by alpha), while the rows are
  # arranged by mean_imp; the row-wise comparison of `shuff` and `true`
  # further down presumably relies on mean_imp increasing with alpha for the
  # true data — confirm.
  curves <- curves %>%
    group_by(dataset) %>%
    arrange(dataset, mean_imp)%>%
    mutate(imps=curves[curves$dataset=="trueData", ]$mean_imp) %>%
    mutate(approx_mse = approx(mean_imp,mean_mse,curves[curves$dataset=="trueData", ]$mean_imp, rule=2)$y,
           approx_sd = approx(mean_imp,sd_mse,curves[curves$dataset=="trueData", ]$mean_imp, rule=2)$y)

  true <- curves[curves$dataset=="trueData",]
  shuff <- curves[curves$dataset!="trueData",]

  # true$div <- (shuff$approx_mse - shuff$approx_sd) - (true$mean_mse)
  true$div <- ifelse((shuff$approx_mse - true$mean_mse)/shuff$approx_sd>1,
                     (shuff$approx_mse - true$mean_mse)/shuff$approx_sd,
                     0)
  # which.max() keeps a single alpha (the first one) when several alphas share
  # the maximum, so the result is always a scalar
  if(max(true$div)>0) {
    alpha_opt <- true$alpha[which.max(true$div)]
  } else {
    alpha_opt <- 0
  }

  eff_opt <- true[true$alpha==alpha_opt,]$mean_imp
  # horizontal offset of the text label, only applied on the rank scale
  padding <- if (type == "rank") {
    if (alpha_opt < 0.2) 17 else -17
  } else {
    0
  }
  if(return_cluster) return(alpha_opt)

  dataset_colors <- setNames(c("grey", "#70AD47"), c("shuffled", "trueData"))

  curves%>%
    ggplot(aes(y=mean_mse, x = mean_imp, color = dataset, fill = dataset))+
    # geom_ribbon(aes(ymin =mean_mse-sd_mse ,
    #                 ymax = mean_mse + sd_mse),  alpha = .4)+
    geom_ribbon(aes(x=imps,ymin =approx_mse-approx_sd ,
                    ymax = approx_mse + approx_sd), alpha = .25)+
    geom_point(alpha = 0.1, size = 0.5) +
    geom_line(aes(x=mean_imp, y = mean_mse), size=1) +
    # geom_line(aes(x=imps, y = approx_mse), col="black")+
    theme_pubr(legend = "none") +
    ylab("MSE/Var(gene)") + xlab("Effective data integration") + ggtitle("MSE=f(effective integration)")+
    scale_color_manual(values = dataset_colors)+
    scale_fill_manual(values = dataset_colors)+
    geom_vline(xintercept = eff_opt, size = 2, col="#4670CD") +
    annotate("text", x=eff_opt+padding, y=max(shuff$mean_mse),
             label=paste("Optimal\nintegration:\nalpha =", alpha_opt),
             angle=0, col = "#4670CD", size =3.5 )
}


# Three example genes for weightedRF, each drawn as a row of three panels
# (effective integration, MSE, MSE as a function of effective integration).
gene_no_benefit <- "AT1G30270"
gene_benefit_optimum <- "AT1G14720"
gene_benefit <- "AT1G30080"

# Assembles the row of panels of one gene; the legend of the MSE panel is
# removed when keep_legend is FALSE.
gene_panel_row <- function(gene, keep_legend = TRUE) {
  panels <- draw_gene_effective_integration(gene, prior = 1) +
    draw_gene_mse(gene)
  if (!keep_legend) {
    panels <- panels + theme(legend.position = "none")
  }
  panels +
    get_opt_alpha_per_gene(gene) +
    plot_annotation(title = gene)
}

n <- gene_panel_row(gene_no_benefit)
o <- gene_panel_row(gene_benefit_optimum, keep_legend = FALSE)
p <- gene_panel_row(gene_benefit, keep_legend = FALSE)


figure <- n / o / p
figure

ggexport(figure, filename = "results/gene_examples_weightedRF.pdf", width = 10, height = 9)


# value of alpha per genes:
alphas_rf <- mcsapply(genes, get_opt_alpha_per_gene, type = "rank",
                      return_cluster = TRUE, mc.cores = 34)
hist(alphas_rf)

save(alphas_rf, file = "results/alpha_per_gene_weighted_RF.rdata")

# Turns the old MSE column names of RF ("<alpha> <rep> <dataset>", space
# separated) into the format returned by inference
# ("RF_<alpha>_<dataset>_<rep>").
# Only names that still contain a space are converted: this makes the step
# safe to re-run on a file that was already converted (converting twice would
# otherwise produce names such as "RF_RF_0_trueData_1__" and overwrite the
# result file with them).
lmses_names <- colnames(lmses)
is_old_format <- str_detect(lmses_names, " ")
if (any(is_old_format)) {
  name_parts <- str_split_fixed(lmses_names[is_old_format], ' ', 3)
  lmses_names[is_old_format] <-
    paste0("RF", "_", name_parts[, 1], '_',
           name_parts[, 3], "_",
           name_parts[, 2]) %>%
    str_replace("_data", "Data")
  colnames(lmses) <- lmses_names
}
save(lmses, file = "results/brf_perumtations_mse_all_genes_predict.rdata")

Doing the same for weightedLASSO:

# weightedLASSO counterparts of `lmses` and `mats`, then the same three-row
# example figure as for weightedRF.
load("results/lasso_perumtations_mse_all_genes_test.rdata")
load("results/100_permutations_lasso_mda_shuffle.rdata")
# load("results/100_permutations_lasso_edges.rdata")

gene_no_benefit <- "AT5G66850"
# the same as RFs
gene_benefit_optimum <- "AT1G14720"
gene_benefit <- "AT2G38180"

# first row keeps the legend of the MSE panel
n <- draw_gene_effective_integration(gene_no_benefit, prior = 1) +
  draw_gene_mse(gene_no_benefit)
n <- n +
  get_opt_alpha_per_gene(gene_no_benefit) +
  plot_annotation(title = gene_no_benefit)

# the two other rows hide it
o <- draw_gene_effective_integration(gene_benefit_optimum, prior = 1) +
  draw_gene_mse(gene_benefit_optimum) +
  theme(legend.position = "none")
o <- o +
  get_opt_alpha_per_gene(gene_benefit_optimum) +
  plot_annotation(title = gene_benefit_optimum)

p <- draw_gene_effective_integration(gene_benefit, prior = 1) +
  draw_gene_mse(gene_benefit) +
  theme(legend.position = "none")
p <- p +
  get_opt_alpha_per_gene(gene_benefit) +
  plot_annotation(title = gene_benefit)


figure <- n / o / p
figure

ggexport(figure, filename = "results/gene_examples_weightedLASSO.pdf", width = 10, height = 9)


# 
# nrt <- sample(int, size = 1)
# draw_gene_effective_integration(nrt, prior = 1, type = "imp")
# draw_gene_effective_integration(nrt, prior = 0.5, type = "imp")
# draw_gene_effective_integration(nrt, prior = 0, type = "imp")
# get_opt_alpha_per_gene(nrt, type = "imp")





# Optimal alpha of every gene for weightedLASSO, computed in parallel
alphas_lasso <- mcsapply(
  X = genes,
  FUN = get_opt_alpha_per_gene,
  type = "rank",
  return_cluster = TRUE,
  mc.cores = 34
)
hist(alphas_lasso)

save(alphas_lasso, file = "results/alpha_per_gene_weighted_LASSO_test.rdata")





### test of specific alphas
# mat_lasso <- weightedLASSO_inference(counts, genes, tfs,
#                                      alpha = alphas_lasso, N = 50, 
#                                      tf_expression_permutation = FALSE,
#                                      int_pwm_noise = 0, mda_type = "shuffle",
#                                      pwm_occurrence = pwm_occurrence,
#                                      nCores = nCores)
# 
# mat_lasso_perm <- weightedLASSO_inference(counts, genes, tfs,
#                                  alpha = alphas_lasso, N = 50, 
#                                  tf_expression_permutation = TRUE,
#                                  int_pwm_noise = 0, mda_type = "shuffle",
#                                  pwm_occurrence = pwm_occurrence,
#                                  nCores = nCores)
# 
# 
# grn <- weightedLASSO_network(mat_lasso, density = 0.005, pwm_occurrence,
#                           genes, tfs, decreasing = TRUE)
# grn_perm <- weightedLASSO_network(mat_lasso_perm, density = 0.005, pwm_occurrence,
#                           genes, tfs, decreasing = TRUE)
# 
# 
# median(mats$LASSO_1_trueData_1["mse",])
# median(mat_lasso["mse",])
# median(mat_lasso_perm["mse",])
# median(mats$LASSO_0_trueData_10["mse",])
# 
# mean(mats$LASSO_1_trueData_1["mse",])
# mean(mat_lasso_perm["mse",])
# mean(mats$LASSO_0_trueData_10["mse",])
# 
# mean(grn$pwm)
# evaluate_network(grn, genes, tfs, validation = "DAPSeq")[c("tpr", "recall")]
# evaluate_network(grn_perm, genes, tfs, validation = "DAPSeq")[c("tpr", "recall")]
########## tests ugly
# Visual spot checks: draws the three panels for a random sample of genes
# whose optimal weightedLASSO alpha is 0.9, then for genes whose optimal
# alpha is 1. The sample size is capped by the number of available genes so
# that sample() does not fail when fewer genes than requested have that alpha.
lasso_1 <- names(alphas_lasso[alphas_lasso==0.9])
for(nrt in sample(lasso_1, size = min(10, length(lasso_1)), replace = FALSE)){
  print(draw_gene_effective_integration(nrt, prior = 1) +
  draw_gene_mse(nrt)+theme(legend.position = "none")+
  get_opt_alpha_per_gene(nrt)+
  plot_annotation(title = nrt))
}

lasso_1 <- names(alphas_lasso[alphas_lasso==1])
for(nrt in sample(lasso_1, size = min(5, length(lasso_1)), replace = FALSE)){
  print(draw_gene_effective_integration(nrt, prior = 1) +
  draw_gene_mse(nrt)+theme(legend.position = "none")+
  get_opt_alpha_per_gene(nrt)+
  plot_annotation(title = nrt))
}

Plotting MSE behaviour for all genes

Depending on their class (optimal alpha different from 0 or not).

weightedRF

load("results/brf_perumtations_mse_all_genes_predict.rdata")
load(file = "results/alpha_per_gene_weighted_RF.rdata")

# genes with a non-zero optimal alpha
pos_class <- names(alphas_rf[alphas_rf!=0])

# mean true-data MSE per gene and per alpha (one "alpha <value>" column each)
mse <- lmses[str_detect(colnames(lmses), "trueData")]
for (alpha in seq(0, 1, by = 0.1)) {
  alpha_columns <- str_detect(colnames(mse), paste0("RF_", as.character(alpha), "_"))
  mse[, paste("alpha", alpha)] <- rowMeans(mse[, alpha_columns])
}
mse <- as.matrix(mse[str_detect(colnames(mse), "alpha")])
mse_interest <- mse[pos_class,]
mse_other <- mse[setdiff(rownames(mse), pos_class),]

library(circlize)
col_fun <- colorRamp2(c(-2, 0, 2), hcl_palette = "Blue-Red 3")

ha <- HeatmapAnnotation(
  alpha = anno_simple(as.numeric(str_remove(colnames(mse), "alpha "))),
  annotation_name_side = "left")

class_colors <- setNames(c("#70AD47", "grey"),
                         nm = c("class of interest", "no integration"))

# heatmaps show the MSE centred and scaled per gene (row)
pdf("results/rf_mse_interest.pdf", height = 5)
scaled_interest <- (mse_interest - rowMeans(mse_interest)) /
  matrixStats::rowSds(mse_interest)
Heatmap(scaled_interest,
        col = col_fun, show_column_names = FALSE,
        width = ncol(mse_interest) * unit(10, "mm"),
        height = nrow(mse_interest) * unit(0.2, "mm"),
        cluster_columns = FALSE, show_row_names = FALSE, top_annotation = ha) +
  rowAnnotation(class = ifelse(alphas_rf[rownames(mse_interest)] > 0,
                               "class of interest", "no integration"),
                col = list(class = class_colors))
dev.off()
## png 
##   2
pdf("results/rf_mse_others.pdf", height = 10)
scaled_other <- (mse_other - rowMeans(mse_other)) /
  matrixStats::rowSds(mse_other)
Heatmap(scaled_other,
        col = col_fun, show_column_names = FALSE,
        width = ncol(mse_other) * unit(10, "mm"),
        height = nrow(mse_other) * unit(0.2, "mm"),
        cluster_columns = FALSE, show_row_names = FALSE, top_annotation = ha) +
  rowAnnotation(class = ifelse(alphas_rf[rownames(mse_other)] > 0,
                               "class of interest", "no integration"),
                col = list(class = class_colors))
dev.off()
## png 
##   2

For the lasso

load("results/lasso_perumtations_mse_all_genes_test.rdata")
load("results/alpha_per_gene_weighted_LASSO_test.rdata")

# genes with a non-zero optimal alpha
pos_class <- names(alphas_lasso[alphas_lasso!=0])

# mean true-data MSE per gene and per alpha (one "alpha <value>" column each)
mse <- lmses[str_detect(colnames(lmses), "trueData")]
for (alpha in seq(0, 1, by = 0.1)) {
  alpha_columns <- str_detect(colnames(mse), paste0("LASSO_", as.character(alpha), "_"))
  mse[, paste("alpha", alpha)] <- rowMeans(mse[, alpha_columns])
}
mse <- as.matrix(mse[str_detect(colnames(mse), "alpha")])
mse_interest <- mse[pos_class,]
mse_other <- mse[setdiff(rownames(mse), pos_class),]

library(circlize)
col_fun <- colorRamp2(c(-2, 0, 2), hcl_palette = "Blue-Red 3")

ha <- HeatmapAnnotation(
  alpha = anno_simple(as.numeric(str_remove(colnames(mse), "alpha "))),
  annotation_name_side = "left")

class_colors <- setNames(c("#70AD47", "grey"),
                         nm = c("class of interest", "no integration"))

# heatmaps show the MSE centred and scaled per gene (row)
pdf("results/lasso_mse_interest.pdf", height = 7)
scaled_interest <- (mse_interest - rowMeans(mse_interest)) /
  matrixStats::rowSds(mse_interest)
Heatmap(scaled_interest,
        col = col_fun, show_column_names = FALSE,
        width = ncol(mse_interest) * unit(10, "mm"),
        height = nrow(mse_interest) * unit(0.2, "mm"),
        cluster_columns = FALSE, show_row_names = FALSE, top_annotation = ha) +
  rowAnnotation(class = ifelse(alphas_lasso[rownames(mse_interest)] > 0,
                               "class of interest", "no integration"),
                col = list(class = class_colors))
dev.off()
## png 
##   2
pdf("results/lasso_mse_others.pdf", height = 7)

scaled_other <- (mse_other - rowMeans(mse_other)) /
  matrixStats::rowSds(mse_other)
Heatmap(scaled_other,
        col = col_fun, show_column_names = FALSE,
        width = ncol(mse_other) * unit(10, "mm"),
        height = nrow(mse_other) * unit(0.2, "mm"),
        cluster_columns = FALSE, show_row_names = FALSE, top_annotation = ha) +
  rowAnnotation(class = ifelse(alphas_lasso[rownames(mse_other)] > 0,
                               "class of interest", "no integration"),
                col = list(class = class_colors))
dev.off()
## png 
##   2
# 
# 
# Heatmap((mse_interest-rowMeans(mse_interest))/matrixStats::rowSds(mse_interest), 
#         col = col_fun, show_column_names = F,
#          width = ncol(mse_interest)*unit(10, "mm"), 
#       height = nrow(mse_interest)*unit(0.2, "mm"),
#         cluster_columns = F, show_row_names = F, top_annotation = ha)+ 
#   rowAnnotation(class = ifelse(alphas_lasso[rownames(mse_interest)]>0, 
#                                "class of interest", "no integration"), 
#                 col = list(class = setNames(c("#70AD47", "grey"), 
#                                             nm = c("class of interest", "no integration"))))+
#   Heatmap(mse_interest, 
#         col = col_fun, show_column_names = F,
#          width = ncol(mse_interest)*unit(10, "mm"), 
#       height = nrow(mse_interest)*unit(0.2, "mm"),
#         cluster_columns = F, show_row_names = F, top_annotation = ha)+ 
#   rowAnnotation(class = ifelse(alphas_lasso[rownames(mse_interest)]>0, 
#                                "class of interest", "no integration"), 
#                 col = list(class = setNames(c("#70AD47", "grey"), 
#                                             nm = c("class of interest", "no integration"))))
# 
# Heatmap((mse_other-rowMeans(mse_other))/matrixStats::rowSds(mse_other),
#         col = col_fun,show_column_names = F,
#         width = ncol(mse_other)*unit(10, "mm"), 
#       height = nrow(mse_other)*unit(0.2, "mm"),
#         cluster_columns = F, show_row_names = F, top_annotation = ha)+ 
#   rowAnnotation(class = ifelse(alphas_lasso[rownames(mse_other)]>0, 
#                                "class of interest", "no integration"), 
#                 col = list(class = setNames(c("#70AD47", "grey"), 
#                                             nm = c("class of interest", "no integration"))))+
#   Heatmap(mse_other,
#         col = col_fun,show_column_names = F,
#         width = ncol(mse_other)*unit(10, "mm"), 
#       height = nrow(mse_other)*unit(0.2, "mm"),
#         cluster_columns = F, show_row_names = F, top_annotation = ha)+ 
#   rowAnnotation(class = ifelse(alphas_lasso[rownames(mse_other)]>0, 
#                                "class of interest", "no integration"), 
#                 col = list(class = setNames(c("#70AD47", "grey"), 
#                                             nm = c("class of interest", "no integration"))))

Intersection between classes

# Genes with a non-zero optimal alpha for each method
rf_class <- names(alphas_rf[alphas_rf!=0])
lasso_class <- names(alphas_lasso[alphas_lasso!=0])

venn <- ggVennDiagram(list("weightedRF" = rf_class,
                           "weightedLASSO" = lasso_class), color="black")+
  scale_fill_gradient(low = "#EEEEEE", high = "#4670CD")+scale_color_manual(values = c("grey", "grey"))+
  theme(legend.position = "none")
venn

# hypergeometric test to assess the significance of the intersect
# The overlap is computed from the two classes instead of being hard-coded.
# With lower.tail = FALSE, phyper() returns P(X > q), so q = overlap - 1
# gives P(X >= overlap), the same convention as the motif enrichment tests
# later in this document.
n_common <- length(intersect(rf_class, lasso_class))
p_enrich <- phyper(q = n_common - 1, m = length(rf_class),
                   n = length(genes) - length(rf_class),
                   k = length(lasso_class), lower.tail = FALSE)
p_enrich
# (the value recorded below presumably comes from the former hard-coded q = 220)
## [1] 5.846883e-06
ggexport(venn, filename = "results/classes_intersection.pdf", width = 4, height = 4)

Motif enrichment in gene lists

# TF / target motif table (provides `pwm_prom`, used below through its
# `TF` and `target` columns)
load("rdata/pwm_prom_jaspar_dap.rdata")

# genes with a non-zero optimal alpha, per method and shared by both
pos_class_rf <- names(alphas_rf[alphas_rf != 0])
pos_class_lasso <- names(alphas_lasso[alphas_lasso != 0])
common <- intersect(pos_class_lasso, pos_class_rf)

# regulators that appear in the motif table
known_tfs <- tfs[tfs %in% pwm_prom$TF]

# Counts, for every TF of `known_tfs`, the rows of `pwm_prom` whose target is
# in `genes` and whose TF is in `tfs`.
#
# Reads the global objects `pwm_prom`, `tfs` and `known_tfs`.
#
# genes : character vector of target genes
# Returns a table named by `known_tfs`, in that order. TFs without any
# matching row get a count of 0 (they previously produced NA, which turned
# the downstream enrichment p-value into NA).
# NOTE(review): this is a number of genes only if `pwm_prom` holds at most one
# row per (TF, target) pair — confirm.
get_number_of_motifs_per_tfs <- function(genes){
  keep <- pwm_prom$target %in% genes & pwm_prom$TF %in% tfs
  table(factor(pwm_prom$TF[keep], levels = known_tfs))
}


# Motif enrichment in the weightedRF class (genes with a non-zero optimal
# alpha for weightedRF). The variable names containing "lasso" are reused
# as-is for every gene list tested in this section.
gene_list <- pos_class_rf
# p-values are initialised to 1 and filled in TF by TF
enrichments_per_pwm <- setNames(rep(1, length(known_tfs)), known_tfs)
n_targets_lasso_in_all <- get_number_of_motifs_per_tfs(genes)
n_targets_lasso_in_group <- get_number_of_motifs_per_tfs(gene_list)
n_group_lasso <- length(gene_list)

for(tf in known_tfs){
  
  # count for that TF's motif over all genes
  n_targets_in_all_tf <- n_targets_lasso_in_all[tf]
  
  # count for that TF's motif within the tested gene list
  n_targets_lasso_in_group_tf <- n_targets_lasso_in_group[tf]
  # upper-tail hypergeometric test: with lower.tail = FALSE and
  # q = observed - 1 this is P(X >= observed)
  p_lasso <- phyper(q=n_targets_lasso_in_group_tf-1,
         m=n_targets_in_all_tf, #white balls
         n=length(genes)-n_targets_in_all_tf, # black balls
         k=n_group_lasso, lower.tail = FALSE)
  
  
  enrichments_per_pwm[tf]<- p_lasso
}
# TFs with a raw p-value below 0.05 (no multiple-testing correction is applied here)
enriched_tfs <- names(which(enrichments_per_pwm < 0.05))
DIANE::get_gene_information(enriched_tfs, organism = "Arabidopsis thaliana")
# Motif enrichment in the weightedLASSO class (genes with a non-zero optimal
# alpha for weightedLASSO); same procedure as for the previous gene list.
gene_list <- pos_class_lasso
# p-values are initialised to 1 and filled in TF by TF
enrichments_per_pwm <- setNames(rep(1, length(known_tfs)), known_tfs)
n_targets_lasso_in_all <- get_number_of_motifs_per_tfs(genes)
n_targets_lasso_in_group <- get_number_of_motifs_per_tfs(gene_list)
n_group_lasso <- length(gene_list)

for(tf in known_tfs){
  
  # count for that TF's motif over all genes
  n_targets_in_all_tf <- n_targets_lasso_in_all[tf]
  
  # count for that TF's motif within the tested gene list
  n_targets_lasso_in_group_tf <- n_targets_lasso_in_group[tf]
  # upper-tail hypergeometric test: with lower.tail = FALSE and
  # q = observed - 1 this is P(X >= observed)
  p_lasso <- phyper(q=n_targets_lasso_in_group_tf-1,
         m=n_targets_in_all_tf, #white balls
         n=length(genes)-n_targets_in_all_tf, # black balls
         k=n_group_lasso, lower.tail = FALSE)
  
  
  enrichments_per_pwm[tf]<- p_lasso
}
# TFs with a raw p-value below 0.05 (no multiple-testing correction is applied here)
enriched_tfs <- names(which(enrichments_per_pwm < 0.05))
DIANE::get_gene_information(enriched_tfs, organism = "Arabidopsis thaliana")
# Motif enrichment in the genes shared by the weightedRF and weightedLASSO
# classes; same procedure as for the two previous gene lists (the variable
# names containing "lasso" are reused as-is).
gene_list <- common
# p-values are initialised to 1 and filled in TF by TF
enrichments_per_pwm <- setNames(rep(1, length(known_tfs)), known_tfs)
n_targets_lasso_in_all <- get_number_of_motifs_per_tfs(genes)
n_targets_lasso_in_group <- get_number_of_motifs_per_tfs(gene_list)
n_group_lasso <- length(gene_list)

for(tf in known_tfs){
  
  # count for that TF's motif over all genes
  n_targets_in_all_tf <- n_targets_lasso_in_all[tf]
  
  # count for that TF's motif within the tested gene list
  n_targets_lasso_in_group_tf <- n_targets_lasso_in_group[tf]
  # upper-tail hypergeometric test: with lower.tail = FALSE and
  # q = observed - 1 this is P(X >= observed)
  p_lasso <- phyper(q=n_targets_lasso_in_group_tf-1,
         m=n_targets_in_all_tf, #white balls
         n=length(genes)-n_targets_in_all_tf, # black balls
         k=n_group_lasso, lower.tail = FALSE)
  
  
  enrichments_per_pwm[tf]<- p_lasso
}
# TFs with a raw p-value below 0.05 (no multiple-testing correction is applied here)
enriched_tfs <- names(which(enrichments_per_pwm < 0.05))
DIANE::get_gene_information(enriched_tfs, organism = "Arabidopsis thaliana")

Precision curves

# Plots precision (and optionally recall) of the inferred networks against
# DAP-Seq as a function of alpha, for the true and the shuffled data, with one
# facet per network density.
#
# validation         : data frame with the columns network_name, model, alpha,
#                      dataset, density, precision and recall
# densities          : densities to keep (matched against the density column)
# precision_only     : if TRUE, return only the precision plot
# baseline_precision : height of the horizontal reference line drawn on the
#                      precision plot; defaults to the value that was
#                      previously hard-coded (presumably the precision of a
#                      reference network — confirm)
# Returns the precision ggplot, or the precision plot stacked over the recall
# plot when precision_only is FALSE.
draw_validation <- function(validation, densities = c(0.005,0.01,0.05), precision_only = FALSE,
                            baseline_precision = 0.331042){
  # per (model, alpha, dataset, density): mean and sd over the repetitions
  data_val <- validation %>%
    filter(density %in% densities)%>%
    group_by(model, alpha, dataset, density) %>%
    mutate(mean_precision = mean(precision, na.rm = TRUE),
           sd_precision = sd(precision, na.rm = TRUE),
           mean_recall = mean(recall, na.rm = TRUE),
           sd_recall = sd(recall, na.rm = TRUE),
           density = paste("D =", density),
           model = str_replace(model, "bRF", "weightedRF"),
           model = str_replace(model, "LASSO", "weightedLASSO")) %>%
    dplyr::select(-network_name) %>%
    mutate(alpha = as.numeric(alpha))

  # pieces shared by the two plots
  dataset_colors <- setNames(c("grey", "#70AD47"), c("shuffled", "trueData"))
  shared_theme <- theme(
    strip.background = element_blank(),
    axis.title.x = element_text(size = 22),
    title = element_text(size = 16),
    strip.text = element_text(size = 16),
    legend.text = element_text(size = 15),
    axis.text = element_text(size = 12),
    legend.position = 'top'
  )

  precision <- data_val %>%
    ggplot(aes(
      x = as.numeric(alpha),
      y = precision,
      color = dataset,
      fill = dataset
    ))+
    ggh4x::facet_nested_wrap(vars(density), ncol = 8, nest_line = TRUE) +
    geom_ribbon(aes(ymin = mean_precision - sd_precision ,
                    ymax = mean_precision + sd_precision  ),
                alpha = .4) + theme_pubr(legend = "top")+
    geom_point(alpha = 0.1) +
    xlab(expression(alpha)) + ylab("Precision") +
    ggtitle(paste("Precision against DAP-Seq")) +
    geom_line(aes(group = dataset, y = mean_precision)) +
    shared_theme +
    scale_color_manual(values = dataset_colors)+
    scale_fill_manual(values = dataset_colors) +
    geom_hline(color = "#C55A11", yintercept = baseline_precision, size= 1.5, show.legend = TRUE)

  if(precision_only)
    return(precision)

  recall <- data_val %>%
    ggplot(aes(
      x = as.numeric(alpha),
      y = recall,
      color = dataset,
      fill = dataset
    ))+
    ggh4x::facet_nested_wrap(vars(density), ncol = 8, nest_line = TRUE, scales = "free") +
    geom_ribbon(aes(ymin = mean_recall - sd_recall ,
                    ymax = mean_recall + sd_recall  ),
                alpha = .4)  +theme_pubr(legend = "top")+
    geom_point(alpha = 0.1) +
    xlab(expression(alpha)) + ylab("Recall") +
    geom_line(aes(group = dataset, y = mean_recall)) +
    ggtitle(paste("Recall against DAPSeq")) +
    shared_theme +
    scale_color_manual(values = dataset_colors)+
    scale_fill_manual(values = dataset_colors)

  precision/recall
}
# Load the saved validation results of each model. `val_dap` is copied to a
# model-specific name right after each load(); presumably both files restore
# an object called `val_dap`, so the copy must happen before the next load()
# (TODO confirm the content of the .rdata files).
load(file = "results/100_permutations_bRF_validation_dap_inf.rdata")
val_rf <- val_dap
load(file = "results/100_permutations_lasso_validation_dap.rdata")
val_lasso <- val_dap

# Validation figure for the weightedLASSO results.
draw_validation(validation = val_lasso)

# Validation figure for the weightedRF results.
draw_validation(validation = val_rf)

Gene-specific GRNs

# Restore previously saved results; `alphas_lasso` and `alphas_rf` used below
# are presumably provided by these two files (TODO confirm).
load("results/alpha_per_gene_weighted_LASSO_test.rdata")
load("results/alpha_per_gene_weighted_RF.rdata")
# Distribution of the alpha values of each model.
hist(alphas_lasso)
hist(alphas_rf)

# Gene-specific inference. Each model is run `nrep` times with the alpha
# vectors `alphas_lasso` / `alphas_rf`, once on the expression data as is and
# once with `tf_expression_permutation = TRUE`. Every result is stored in
# `mats` under the key <model>_<dataset>_<rep>.
nCores <- 45
nrep <- 10
mats <- list()
for (rep in seq_len(nrep)) { # repetitions expose the run-to-run variability

  # weightedLASSO without TF expression permutation
  mat_lasso <- weightedLASSO_inference(
    counts, genes, tfs,
    alpha = alphas_lasso,
    N = 50,
    tf_expression_permutation = FALSE,
    int_pwm_noise = 0,
    mda_type = "shuffle",
    pwm_occurrence = pwm_occurrence,
    nCores = nCores
  )

  # weightedLASSO with TF expression permutation
  mat_lasso_perm <- weightedLASSO_inference(
    counts, genes, tfs,
    alpha = alphas_lasso,
    N = 50,
    tf_expression_permutation = TRUE,
    int_pwm_noise = 0,
    mda_type = "shuffle",
    pwm_occurrence = pwm_occurrence,
    nCores = nCores
  )

  mats[[paste("LASSO", "trueData", rep, sep = "_")]] <- mat_lasso
  mats[[paste("LASSO", "shuffled", rep, sep = "_")]] <- mat_lasso_perm

  # weightedRF; no permutation argument is given here, so the function's own
  # default applies
  mat_rf <- weightedRF_inference(
    counts, genes, tfs,
    nTrees = 2000,
    alpha = alphas_rf,
    pwm_occurrence = pwm_occurrence,
    nCores = nCores,
    importance = "%IncMSE"
  )

  # weightedRF with TF expression permutation
  mat_rf_perm <- weightedRF_inference(
    counts, genes, tfs,
    nTrees = 2000,
    alpha = alphas_rf,
    tf_expression_permutation = TRUE,
    pwm_occurrence = pwm_occurrence,
    nCores = nCores,
    importance = "%IncMSE"
  )

  mats[[paste("RF", "trueData", rep, sep = "_")]] <- mat_rf
  mats[[paste("RF", "shuffled", rep, sep = "_")]] <- mat_rf_perm
}
save(mats, file = "results/gene_specific_grns.rdata")


# Turn every inferred matrix of `mats` into networks at several densities and
# collect the "mse" row of each matrix as one column of `lmses`.
#
# `edges` is keyed by <name of the matrix>_<density>; `lmses` has one row per
# entry of `genes` and one column per matrix.
edges <- list()
lmses <- data.frame(row.names = genes)
for (name in names(mats)) {
  # The "mse" row does not depend on the density, so it is extracted once per
  # matrix instead of once per density.
  lmses[, name] <- mats[[name]]["mse", ]

  for (density in c(0.005, 0.01, 0.05)) { # exploring importance threshold stringency
    edge_name <- paste0(name, "_", density)

    # The model is identified from the matrix name; anything that is not
    # LASSO is handled as RF.
    if (str_detect(name, "LASSO")) {
      edges[[edge_name]] <-
        weightedLASSO_network(mats[[name]], density = density, pwm_occurrence,
                              genes, tfs, decreasing = TRUE)
    } else {
      edges[[edge_name]] <-
        weightedRF_network(mats[[name]], density = density, pwm_occurrence,
                           genes, tfs)
    }
  }
}
save(lmses, file = "results/gene_specific_mse.rdata")
save(edges, file = "results/gene_specific_edges.rdata")

# Reload the saved objects, so that the analysis can also be resumed from this
# point without running the inference again.
load("results/gene_specific_grns.rdata")
load("results/gene_specific_mse.rdata")
load("results/gene_specific_edges.rdata")

# Evaluate every gene-specific network against DAP-Seq.
val_specific <- evaluate_networks(
  edges,
  validation = "DAPSeq",
  nCores = 30,
  input_genes = genes,
  input_tfs = tfs
)
## 13.695 sec elapsed

# The network names are split on "_" into four fields, each stored in a column
# of its own.
settings <- c("model", "dataset", "rep", "density")
val_specific[, settings] <- str_split_fixed(
  val_specific$network_name, "_", length(settings)
)
save(val_specific, file = "results/gene_specific_validation.rdata")

# Precision plots of the two models, restricted to the 0.005 density.
precision_rf <- draw_validation(
  validation = val_rf,
  densities = 0.005,
  precision_only = TRUE
)
precision_lasso <- draw_validation(
  validation = val_lasso,
  densities = 0.005,
  precision_only = TRUE
)


# Boxplot of the precision of the gene-specific networks of one model, with
# one box per dataset ("shuffled" / "trueData") and one point per network.
#
# validation: data frame with a `network_name` column formatted as
#   <model>_<dataset>_<rep>_<density> and a `precision` column.
# model_name: value of the <model> field to keep ("RF" or "LASSO").
# target_density: value of the <density> field to keep.
# reference_precision: y position of the horizontal reference line (value
#   hard-coded in the original script; its origin is not visible here).
# Returns a ggplot object.
draw_specific_precision <- function(validation, model_name,
                                    target_density = 0.005,
                                    reference_precision = 0.331042) {
  dataset_colors <- setNames(c("grey", "#70AD47"), c("shuffled", "trueData"))

  validation %>%
    separate(network_name,
             into = c("model", "dataset", "rep", "density"),
             sep = "_") %>%
    # `density` is a character column here; R compares it with the number
    # through its character form ("0.005").
    filter(density == target_density & model == model_name) %>%
    mutate(density = paste("D =", density),
           model = str_replace(model, "RF", "weightedRF"),
           model = str_replace(model, "LASSO", "weightedLASSO")) %>%
    ggplot(aes(x = dataset, y = precision, color = dataset, fill = dataset)) +
    scale_color_manual(values = dataset_colors) +
    scale_fill_manual(values = dataset_colors) +
    geom_hline(color = "#C55A11", yintercept = reference_precision,
               size = 1.5, show.legend = TRUE) +
    geom_boxplot(size = 1, alpha = 0.5, outlier.alpha = 0) +
    theme_pubr() +
    geom_jitter(width = 0.2, size = 2) +
    theme(
      strip.background = element_blank(),
      axis.title.x = element_text(size = 22),
      title = element_text(size = 16),
      strip.text = element_text(size = 16),
      legend.text = element_text(size = 15),
      axis.text = element_blank(),
      legend.position = 'none'
    ) +
    xlab("") +
    ylim(c(0.2, 0.45)) +
    ylab("")
}

spec_rf <- draw_specific_precision(val_specific, "RF")

spec_lasso <- draw_specific_precision(val_specific, "LASSO")
# gene specific MSE
load("results/gene_specific_mse.rdata")

# Boxplot of the median MSE of the gene-specific models of one model type,
# with one box per dataset ("shuffled" / "trueData") and one point per
# repetition.
#
# mse: data frame with one row per gene (gene names as row names) and one
#   column per run, the column names being <model>_<dataset>_<rep>.
# model_name: value of the <model> field to keep ("RF" or "LASSO").
# Returns a ggplot object.
draw_specific_mse <- function(mse, model_name) {
  dataset_colors <- setNames(c("grey", "#70AD47"), c("shuffled", "trueData"))

  mse %>%
    rownames_to_column("gene") %>%
    reshape2::melt() %>%
    separate(variable,
             into = c("model", "dataset", "rep"),
             sep = "_") %>%
    filter(model == model_name) %>%
    # one median over the genes for each repetition of each dataset
    group_by(dataset, rep) %>%
    summarise(median_MSE = median(value, na.rm = TRUE)) %>%
    ggplot(aes(x = dataset, y = median_MSE, color = dataset, fill = dataset)) +
    scale_color_manual(values = dataset_colors) +
    scale_fill_manual(values = dataset_colors) +
    geom_boxplot(size = 1, alpha = 0.5, outlier.alpha = 0) +
    theme_pubr() +
    geom_jitter(width = 0.2, size = 2) +
    theme(
      strip.background = element_blank(),
      axis.title.x = element_text(size = 22),
      title = element_text(size = 16),
      strip.text = element_text(size = 16),
      legend.text = element_text(size = 15),
      axis.text.x = element_blank(),
      legend.position = 'none'
    ) +
    xlab("") +
    ylab("")
}

mse_lasso_spec <- draw_specific_mse(lmses, "LASSO")

mse_rf_spec <- draw_specific_mse(lmses, "RF")


# MSE for global alphas
# Plot the median MSE as a function of alpha for one model.
#
# Reads the global data frame `lmses`, whose column names are split on "_"
# into <model>_<alpha>_<dataset>_<rep>; only the columns whose <rep> field is
# at most 10 are used. For each (dataset, rep, alpha) the median of the MSE
# over the rows is computed, then its mean and standard deviation over the
# repetitions are attached for each (alpha, dataset).
#
# model: model label to keep, matched against the <model> field of the column
#   names (the script calls it with "LASSO" and "RF").
# Returns a ggplot object: points are the per-repetition medians, the line is
# their mean and the ribbon spans the mean +/- one standard deviation.
# Stops with an error when no column of `lmses` carries the requested label.
draw_mse <- function(model) {
  # The long table built below has a column named `model`. Inside filter()
  # that column masks the argument, so `model == model` compared the column
  # with itself and kept every row. The requested label is therefore copied
  # to a name that cannot clash with a column.
  model_name <- model

  rep_index <- as.numeric(str_split_fixed(colnames(lmses), "_", 4)[, 4])
  long_mse <- lmses[rep_index <= 10] %>%
    rownames_to_column("gene") %>%
    reshape2::melt() %>%
    separate(variable,
             into = c("model", "alpha", "dataset", "rep"),
             sep = "_")

  if (!model_name %in% long_mse$model) {
    stop(
      "No column of `lmses` matches model '", model_name, "'; available: ",
      paste(unique(long_mse$model), collapse = ", "),
      call. = FALSE
    )
  }

  data <- long_mse %>%
    filter(model == model_name) %>%
    group_by(dataset, rep, alpha) %>%
    summarise(median_MSE = median(value, na.rm = TRUE)) %>%
    group_by(alpha, dataset) %>%
    mutate(mean_median_mse = mean(median_MSE),
           sd_median_mse = sd(median_MSE)) %>%
    mutate(alpha = as.numeric(alpha))

  dataset_colors <- setNames(c("grey", "#70AD47"), c("shuffled", "trueData"))

  data %>%
    ggplot(aes(x = alpha, y = median_MSE, color = dataset, fill = dataset)) +
    scale_color_manual(values = dataset_colors) +
    scale_fill_manual(values = dataset_colors) +
    geom_line(aes(y = mean_median_mse, group = dataset)) +
    geom_ribbon(aes(ymin = mean_median_mse - sd_median_mse,
                    ymax = mean_median_mse + sd_median_mse),
                alpha = .4) +
    theme_pubr() +
    # `width` is not a geom_point() parameter (it was ignored with a
    # warning), so only `size` is passed.
    geom_point(size = 2) +
    theme(
      strip.background = element_blank(),
      axis.title.x = element_text(size = 22),
      title = element_text(size = 16),
      strip.text = element_text(size = 16),
      legend.text = element_text(size = 15),
      axis.text.x = element_blank(),
      legend.position = 'none'
    ) +
    xlab("") +
    ylab("")
}

# MSE curves for the global alphas. draw_mse() is called right after each
# load(), one model at a time.
load("results/lasso_perumtations_mse_all_genes_test.rdata")
mse_lasso <- draw_mse("LASSO")

load("results/brf_perumtations_mse_all_genes_predict.rdata")
mse_rf <- draw_mse("RF")

# Top row: MSE panels, all with the same y range.
panel_mse_lasso <- mse_lasso +
  ylim(c(0.195, 0.25)) +
  labs(subtitle = "LASSO") +
  ggtitle("MSE")
panel_mse_lasso_spec <- mse_lasso_spec +
  ylim(c(0.195, 0.25)) +
  labs(subtitle = "LASSO")
panel_mse_rf <- mse_rf +
  ylim(c(0.195, 0.25)) +
  labs(subtitle = "RF")
panel_mse_rf_spec <- mse_rf_spec +
  ylim(c(0.195, 0.25)) +
  labs(subtitle = "RF")

# Bottom row: precision panels.
panel_precision_lasso <- precision_lasso +
  ylim(c(0.2, 0.45)) +
  ggtitle("Precision") +
  labs(subtitle = "LASSO") +
  theme(legend.position = 'none')
panel_spec_lasso <- spec_lasso +
  labs(subtitle = "LASSO")
panel_precision_rf <- precision_rf +
  labs(subtitle = "RF") +
  ylim(c(0.2, 0.45)) +
  theme(legend.position = 'none') +
  ggtitle("")
panel_spec_rf <- spec_rf +
  labs(subtitle = "RF")

# Eight panels on a grid of four equally wide columns, legends collected.
panel_mse_lasso +
  panel_mse_lasso_spec +
  panel_mse_rf +
  panel_mse_rf_spec +
  panel_precision_lasso +
  panel_spec_lasso +
  panel_precision_rf +
  panel_spec_rf +
  plot_layout(guides = "collect", ncol = 4, widths = c(1, 1, 1, 1))